import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from pydicom import dcmread
from torch.utils.data import Dataset
from lib.mask_functions import rle2mask
class TrainingSet(Dataset):
    """Training DICOM images paired with their segmentation masks.

    The CSV index supplies one ``ImageId`` and one ``EncodedPixels``
    run-length encoding per row; an ``EncodedPixels`` value of ``-1`` marks
    an image that has no mask.
    """

    def __init__(self, csv_path='./dataset/train.csv',
                 image_dir='./dataset/dicom-images-train'):
        """Load the CSV index.

        Args:
            csv_path: CSV file with ``ImageId`` and ``EncodedPixels`` columns.
            image_dir: Directory containing one ``<ImageId>.dcm`` per row.
        """
        self.df = pd.read_csv(csv_path)
        self.image_dir = image_dir

    @staticmethod
    def _build_mask(rle, height, width):
        """Return the mask for *rle*, all zeros when *rle* is the -1 sentinel."""
        # str()/strip() tolerate a padded sentinel (' -1') and a column that
        # pandas parsed as integers.
        if str(rle).strip() == '-1':
            # Shaped (height, width) so it lines up with the pixel array,
            # which is (Rows, Columns).
            return np.zeros((height, width))
        # Transposed exactly as before; presumably rle2mask yields a
        # (width, height) array -- verify against lib.mask_functions.
        return rle2mask(rle, width, height).T

    def __getitem__(self, idx):
        """Return ``(image / 255, mask / 255)`` for the row at position *idx*.

        Raises:
            IndexError: If *idx* is out of range. This is what lets plain
                ``for sample in dataset`` iteration terminate.
        """
        # .iloc is positional and raises IndexError when out of bounds; the
        # previous label lookup raised KeyError, which broke iteration.
        image_id = self.df['ImageId'].iloc[idx]
        rle = self.df['EncodedPixels'].iloc[idx]
        dcm = dcmread('{}/{}.dcm'.format(self.image_dir, image_id))
        image = dcm.pixel_array
        mask = self._build_mask(rle, dcm.Rows, dcm.Columns)
        # NOTE(review): dividing by 255 assumes 8-bit pixels and 0/255 masks.
        return (image / 255), (mask / 255)

    def __len__(self):
        """Return the number of rows in the CSV index."""
        return len(self.df)
def training_samples(n=4, m=6):
    """Plot a grid of training images that carry a non-empty mask.

    Walks the training set in order and draws each image whose mask has a
    positive maximum, with the mask overlaid in red, until the grid is full
    or the dataset is exhausted.

    Args:
        n: Number of grid rows.
        m: Number of grid columns.
    """
    # squeeze=False keeps `axes` two-dimensional even when n or m is 1, so
    # the [row, col] indexing below always works.
    _fig, axes = plt.subplots(n, m, figsize=(m * 5, n * 5), squeeze=False)
    dataset = TrainingSet()
    capacity = n * m
    shown = 0
    # Index explicitly instead of iterating the dataset: this does not rely
    # on __getitem__ raising IndexError to end the loop.
    for sample_idx in range(len(dataset)):
        if shown == capacity:
            break
        image, mask = dataset[sample_idx]
        if not mask.max() > 0:
            continue
        ax = axes[shown // m, shown % m]
        ax.set_title('Train {}'.format(sample_idx))
        ax.imshow(image, cmap=plt.cm.bone)
        ax.imshow(mask, alpha=0.3, cmap='Reds')
        shown += 1
# Draw the training-sample grid with the default 4 x 6 layout; this runs as
# soon as the file is executed or imported.
training_samples()
class TestSet(Dataset):
    """Test DICOM images listed in a CSV index with an ``ImageId`` column."""

    def __init__(self, csv_path='./dataset/test.csv',
                 image_dir='./dataset/dicom-images-test'):
        """Load the CSV index.

        Args:
            csv_path: CSV file with an ``ImageId`` column.
            image_dir: Directory containing one ``<ImageId>.dcm`` per row.
        """
        self.df = pd.read_csv(csv_path)
        self.image_dir = image_dir

    def __getitem__(self, idx):
        """Return ``image / 255`` for the row at position *idx*.

        Raises:
            IndexError: If *idx* is out of range. This is what lets plain
                ``for image in dataset`` iteration terminate.
        """
        # .iloc is positional and raises IndexError when out of bounds; the
        # previous label lookup raised KeyError, which broke iteration.
        image_id = self.df['ImageId'].iloc[idx]
        dcm = dcmread('{}/{}.dcm'.format(self.image_dir, image_id))
        # NOTE(review): dividing by 255 assumes 8-bit pixel data.
        return dcm.pixel_array / 255

    def __len__(self):
        """Return the number of rows in the CSV index."""
        return len(self.df)
def test_samples(n=4, m=6):
    """Plot the first ``n * m`` test images in an n-by-m grid.

    Fewer images are drawn when the test set is smaller than the grid.

    Args:
        n: Number of grid rows.
        m: Number of grid columns.
    """
    # squeeze=False keeps `axes` two-dimensional even when n or m is 1, so
    # the [row, col] indexing below always works.
    _fig, axes = plt.subplots(n, m, figsize=(m * 5, n * 5), squeeze=False)
    dataset = TestSet()
    # Bound the loop by the dataset length so a short test set ends cleanly
    # instead of relying on __getitem__ raising IndexError.
    for idx in range(min(len(dataset), n * m)):
        image = dataset[idx]
        ax = axes[idx // m, idx % m]
        ax.set_title('Test {}'.format(idx))
        ax.imshow(image, cmap=plt.cm.bone)
# Draw the test-sample grid with the default 4 x 6 layout; this runs as soon
# as the file is executed or imported.
test_samples()